import traceback
import sys
import io
from functools import partial
from copy import deepcopy
from colorama import Fore, Style
import termcolor

from ProAgent.loggers.logs import logger
from ProAgent.n8n_parser.workflow import n8nPythonWorkflow
from ProAgent.n8n_parser.node import n8nPythonNode
from ProAgent.utils import TestResult, TestDataType, RunTimeStatus, NodeType

from ProAgent.n8n_tester.mock_input import MockInput
from ProAgent.n8n_tester.run_node import run_node, n8nRunningException, anonymous_class

class n8nNodeRunner():
    """Execute the generated Python code of a single n8n node and record the result.

    ``n8nPythonCodeRunner.run_code`` places one instance per node into the shared
    execution namespace under the node's function name, so calls made by generated
    workflow code to that name are routed through ``__call__``.
    """

    def __init__(self, node: n8nPythonNode, name_space: dict):
        """
        Args:
            node: The node whose code (``node.print_self()``) is executed.
            name_space: Dictionary used as the globals of the ``exec`` call. It is
                mutated on every call (input data and the ``transparent_*`` helper are
                registered in it).
        """
        self.node = node
        self.name_space: dict = name_space

    def __call__(self, input_data):
        """
        Execute the node's function with the given input data and return its output.

        The node's code is rendered, followed by a line calling the function with
        ``input_data``, and executed with ``self.name_space`` as globals. Visit count,
        input, output and status are recorded on ``self.node.last_runtime_info``.

        Args:
            input_data: The input data passed to the node's function.

        Returns:
            The value returned by the node's function.

        Raises:
            n8nRunningException: If an error occurs while executing the node code. The
                exception carries the lines of generated code around the failing call.
        """
        self.node.last_runtime_info.visit_times += 1
        # `data_type` is the field TestResult is constructed with (see run_code).
        self.node.last_runtime_info.data_type = TestDataType.ActionInput
        self.node.last_runtime_info.input_data = input_data

        func_name = self.node.get_name()
        input_data_var_name = f"{func_name}_input_data"
        output_data_var_name = f"{func_name}_output_data"
        self.name_space[input_data_var_name] = input_data
        # Node-bound helper registered as `transparent_<node type>`; presumably called by
        # the generated node code — confirm against n8nPythonNode.print_self.
        self.name_space[f"transparent_{self.node.node_meta.node_type.name}"] = partial(anonymous_class, node=self.node)
        node_code = (
            "from typing import List, Dict\n"
            + "\n".join(self.node.print_self())
            + f"\n{output_data_var_name}={func_name}({input_data_var_name})"
        )

        try:
            # Separate locals: module-level definitions and the result variable land in
            # `local_vars`, while global lookups inside the function hit the namespace.
            local_vars = {}
            exec(node_code, self.name_space, local_vars)
            output_data = local_vars[output_data_var_name]
            self.node.last_runtime_info.output_data = deepcopy(output_data)
            self.node.last_runtime_info.runtime_status = RunTimeStatus.FunctionExecuteSuccess
            return output_data
        except Exception as e:
            print(termcolor.colored(traceback.format_exc() + str(e)))
            error = n8nRunningException(e)
            error.error_message = f"{type(e).__name__}: " + str(e)
            self.node.last_runtime_info.runtime_status = RunTimeStatus.ErrorRaisedHere
            error_tb = traceback.extract_tb(e.__traceback__)

        error.add_context_stack(self._error_context(node_code, error_tb))
        raise error

    def _error_context(self, node_code: str, error_tb) -> list:
        """Build the context lines attached to a node failure.

        The failing line is taken to be the first generated line containing ``.run``.
        If there is none, the innermost traceback frame from exec'd code (filename
        ``"<string>"``) is used when its line number fits ``node_code``. Otherwise only
        the header line is returned, instead of crashing on an unbound line number.
        """
        code_lines = node_code.split("\n")
        line_no = next((k + 1 for k, line in enumerate(code_lines) if ".run" in line), None)
        if line_no is None:
            exec_frames = [frame for frame in error_tb if frame.filename == "<string>"]
            if exec_frames and exec_frames[-1].lineno and exec_frames[-1].lineno <= len(code_lines):
                line_no = exec_frames[-1].lineno

        context = [f"In Function: transparent_{self.node.node_meta.node_type.name}"]
        if line_no is not None:
            context.extend(self._code_snippet(code_lines, line_no))
        return context

    @staticmethod
    def _code_snippet(code_lines: list, line_no: int) -> list:
        """Return 1-based line ``line_no`` marked with ``-->`` plus its non-blank neighbours."""
        snippet = ["--> " + code_lines[line_no - 1]]
        if line_no < len(code_lines) and code_lines[line_no].strip():
            snippet.append("    " + code_lines[line_no])
        if line_no >= 2 and code_lines[line_no - 2].strip():
            snippet.insert(0, "    " + code_lines[line_no - 2])
        return snippet

class n8nWorkflowRunner():
    """Execute the generated Python code of a workflow function and record the result.

    ``n8nPythonCodeRunner.run_code`` places one instance per workflow into the shared
    execution namespace under the workflow's name. Failures are raised as
    ``n8nRunningException`` carrying a snippet of the failing workflow line and a dump
    of the last recorded output of every node runner in the namespace.
    """

    # Namespace entries excluded from the output-data dump on failure.
    _SKIPPED_NAMES = ('mainWorkflow', 'mainWorkflow_input_data', '__builtins__')

    def __init__(self, workflow_name: str,  workflow: n8nPythonWorkflow, name_space: dict):
        """
        Initializes a new instance of the class.

        Args:
            workflow_name (str): The name of the workflow function; the generated call
                line invokes a function with this name.
            workflow (n8nPythonWorkflow): The n8nPythonWorkflow object.
            name_space (dict): Dictionary used as the globals of the ``exec`` call.

        Returns:
            None
        """
        self.workflow = workflow
        self.workflow_name = workflow_name
        self.name_space: dict = name_space

    def __call__(self, input_data):
        """
        Executes the workflow with the given input data.

        The workflow's ``implement_code`` is executed, followed by a line calling
        ``<workflow_name>(input_data)``, with ``self.name_space`` as globals. Visit
        count, input, output and status are recorded on ``last_runtime_info``.

        Args:
            input_data: The input data to be passed to the workflow.

        Returns:
            The output data generated by the workflow.

        Raises:
            n8nRunningException: If an error occurs during the execution of the workflow.
                An exception already of that type (e.g. from a node runner) is re-raised
                with extra context; any other exception is wrapped.
        """
        self.workflow.last_runtime_info.visit_times += 1
        # `data_type` is the field TestResult is constructed with (see run_code).
        self.workflow.last_runtime_info.data_type = TestDataType.ActionInput
        self.workflow.last_runtime_info.input_data = input_data

        input_data_var_name = f"{self.workflow_name}_input_data"
        output_data_var_name = f"{self.workflow_name}_output_data"
        self.name_space[input_data_var_name] = input_data
        workflow_code = self.workflow.implement_code + f"\n{output_data_var_name}={self.workflow_name}({input_data_var_name})"

        try:
            local_vars = {}
            exec(workflow_code, self.name_space, local_vars)
            output_data = local_vars[output_data_var_name]
            self.workflow.last_runtime_info.output_data = deepcopy(output_data)
            self.workflow.last_runtime_info.runtime_status = RunTimeStatus.FunctionExecuteSuccess
            return output_data
        except Exception as e:
            error_tb = traceback.extract_tb(e.__traceback__)
            if isinstance(e, n8nRunningException):
                # Raised by an inner runner: keep its message, only add our context.
                error = e
                self.workflow.last_runtime_info.runtime_status = RunTimeStatus.ErrorRaisedInner
            else:
                error = n8nRunningException(e)
                error.error_message = f"{type(e).__name__}: " + str(e)
                self.workflow.last_runtime_info.runtime_status = RunTimeStatus.ErrorRaisedHere

        code_lines = workflow_code.split("\n")
        frame = self._find_error_frame(error_tb)
        # Only index `code_lines` with a line number that belongs to the generated code.
        if frame is not None and frame.lineno and frame.lineno <= len(code_lines):
            error.add_context_stack([f"In Function: {frame.name}"] + self._code_snippet(code_lines, frame.lineno))

        error.add_context_stack([self._output_data_info()])
        raise error

    @staticmethod
    def _find_error_frame(error_tb):
        """Pick the traceback frame that lies inside this workflow's generated code.

        Frames from exec'd code have filename ``"<string>"``. The first one is the
        module-level code (normally the appended call line); the second is the function
        it called. Falls back to the only ``"<string>"`` frame, or returns ``None`` when
        the generated code never ran (e.g. it failed to compile).
        """
        exec_frames = [frame for frame in error_tb if frame.filename == "<string>"]
        if len(exec_frames) >= 2:
            return exec_frames[1]
        return exec_frames[-1] if exec_frames else None

    @staticmethod
    def _code_snippet(code_lines: list, line_no: int) -> list:
        """Return 1-based line ``line_no`` marked with ``-->`` plus its non-blank neighbours."""
        snippet = ["--> " + code_lines[line_no - 1]]
        if line_no < len(code_lines) and code_lines[line_no].strip():
            snippet.append("    " + code_lines[line_no])
        if line_no >= 2 and code_lines[line_no - 2].strip():
            snippet.insert(0, "    " + code_lines[line_no - 2])
        return snippet

    def _output_data_info(self) -> str:
        """Describe the last recorded output of each node runner in the namespace.

        An entry is listed when it exposes ``.node.last_runtime_info.output_data``.
        Node runners do; workflow runners (which hold ``.workflow``) do not. Attribute
        probing uses ``getattr`` so unusual namespace values cannot crash error reporting.
        """
        parts = ["Note: if there is 'KeyError' in the error message, it may be due to the wrong usage of output data. The output data info may help you: \n[Output Data Info]\n"]
        for action_name, value in self.name_space.items():
            if action_name in self._SKIPPED_NAMES or 'input_data' in action_name or not value:
                continue
            node = getattr(value, "node", None)
            runtime_info = getattr(node, "last_runtime_info", None) if node else None
            if runtime_info and hasattr(runtime_info, "output_data"):
                parts.append(f"the output data of function `{action_name}` is: `{runtime_info.output_data}`\n")
        return "".join(parts)


class n8nPythonCodeRunner():
    """Run the current n8n-python code (nodes and workflows) and render it as text.

    Expected call order: ``flash`` to load the code objects, ``run_code`` to execute
    ``mainWorkflow``, then ``print_code`` / ``print_clean_code`` to render the code.
    """

    def __init__(self):
        # Formatted error stack of the last `run_code` call ("" when it succeeded).
        self.error_stack_str = ""
        # Output shown in `print_code`; `run_code` resets it to "" and nothing in this
        # class captures output into it.
        self.std_output = ""
        self.mock_interface = MockInput()

    def flash(self, main_workflow: n8nPythonWorkflow, workflows: dict[str, n8nPythonWorkflow], nodes: list[n8nPythonNode]):
        """
        Pull the latest n8n-python-code.
        Initializes the `workflows` and `nodes` attributes of the class.

        Parameters:
            main_workflow (n8nPythonWorkflow): The main workflow object. It is stored under
                the key ``"mainWorkflow"``, replacing any entry of that name.
            workflows (dict[str, n8nPythonWorkflow]): A dictionary containing the workflows.
                It is copied, not mutated.
            nodes (list[n8nPythonNode]): A list of nodes.

        Returns:
            None
        """
        self.workflows = dict(workflows)
        self.workflows["mainWorkflow"] = main_workflow
        self.nodes = nodes

    def run_code(self):
        """
        Execute the current code starting from ``mainWorkflow``.

        1. Initialize the runtime information for all functions.
        2. Execute the current code with the mocked input of the first implemented
           trigger node (``None`` if there is none), updating the runtime information of
           every function that is reached.
        3. Store the formatted error stack in ``self.error_stack_str`` ("" on success).
        """
        self._reset_runtime_info()
        trigger_input = self._activate_trigger()
        name_space = self._build_name_space()

        self.error_stack_str = ""
        self.std_output = ""
        error_stack = []
        try:
            name_space["mainWorkflow"](trigger_input)
        except n8nRunningException as e:
            # Context is added innermost-first while unwinding, so reverse it to print
            # the outermost function first.
            for code_lines in reversed(e.code_stack):
                error_stack.extend(code_lines)
                error_stack.append("------------------------")
            error_stack.append(e.error_message)
        self.error_stack_str = "\n".join(error_stack)

    def _reset_runtime_info(self):
        """Mark every workflow and node as not yet called."""
        for runnable in [*self.workflows.values(), *self.nodes]:
            runnable.last_runtime_info = TestResult(
                data_type=TestDataType.NoInput,
                runtime_status=RunTimeStatus.DidNotBeenCalled,
            )

    def _activate_trigger(self):
        """Return a mock example input for the first implemented trigger node.

        That node is recorded as successfully activated, with the input as its output.
        Returns ``None`` when no node is an implemented trigger.
        """
        for node in self.nodes:
            if node.node_meta.node_type == NodeType.trigger and node.implemented:
                trigger_input = self.mock_interface.get_node_example_input(node, top_k=1)
                node.last_runtime_info = TestResult(
                    data_type=TestDataType.NoInput,
                    runtime_status=RunTimeStatus.TriggerAcivatedSuccess,
                    visit_times=1,
                    output_data=deepcopy(trigger_input),
                )
                return trigger_input
        return None

    def _build_name_space(self) -> dict:
        """Create the execution namespace that maps each function name to its runner.

        Every runner holds this same dict, so entries registered while one runner
        executes are visible to the others. Workflows are added after nodes and so
        replace a node of the same name.
        """
        name_space = {}
        for node in self.nodes:
            name_space[node.get_name()] = n8nNodeRunner(node=node, name_space=name_space)
        for workflow_name, workflow in self.workflows.items():
            name_space[workflow_name] = n8nWorkflowRunner(workflow_name=workflow_name, workflow=workflow, name_space=name_space)
        return name_space

    def print_clean_code(self, indent = 0):
        """Render the code of all nodes and workflows without any comments.

        Args:
            indent: Number of spaces prefixed to each collected entry. Entries that
                contain newlines are only indented at their start.

        Returns:
            The rendered code as a single string.
        """
        lines = []
        for node in self.nodes:
            lines.extend(node.print_self())
            lines.append("\n\n")

        for workflow in self.workflows.values():
            lines.append(workflow.print_self())
            lines.append("\n\n")

        return "\n".join(" " * indent + line for line in lines)

    def print_code(self, indent = 0):
        """Render the current code annotated with runtime information.

        Each node is preceded by a triple-quoted block containing its parameter
        descriptions and last runtime info. Each workflow is preceded by its last runtime
        info. The output ends with the result of the last run (``std_output`` and
        ``error_stack_str``).

        Args:
            indent: Number of spaces prefixed to each collected entry.

        Returns:
            The rendered code as a single string.
        """
        lines = []
        for node in self.nodes:
            lines.append("\"\"\"Function param descriptions: ")
            if node.params:
                for k, param in enumerate(node.params.values()):
                    lines.extend(param.to_description(prefix_ids=f"{k}", indent=0, max_depth=1))
            else:
                lines.append("This function doesn't need params")
            lines.append(node.last_runtime_info.to_str())
            lines.append("\"\"\"")
            lines.extend(node.print_self())
            lines.append("\n\n")

        for workflow in self.workflows.values():
            lines.append("\"\"\"")
            lines.append(workflow.last_runtime_info.to_str())
            lines.append("\"\"\"")

            lines.append(workflow.print_self())
            lines.append("\n\n")

        running_prompt = f"""
The directly running result for now codes with print results are as following:

{self.std_output}
{self.error_stack_str}

You can also see the runnning result for all functions in there comments."""

        lines.append("\"\"\"")
        lines.append(running_prompt)
        lines.append("\"\"\"")

        return "\n".join(" " * indent + line for line in lines)

